/**
 * 
 */
package edu.berkeley.nlp.PCFGLA;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveUnaryRule.SubRule;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees.PennTreeRenderer;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.Pair;

/**
 * @author petrov
 *
 */
public class HierarchicalAdaptiveBinaryRule extends HierarchicalBinaryRule {
	private static final long serialVersionUID = 1L;

	public short[][][] mapping;
	Tree<Double> hierarchy;
	public int nParam;
	public SubRule[] subRuleList;

	

	// assume for now that the rule being passed in is unsplit
	public HierarchicalAdaptiveBinaryRule(BinaryRule b) {
		super(b);
		hierarchy = new Tree<Double>(0.0);
		scores = new double[1][1][1];
		mapping = new short[1][1][1]; //to parameters
		nParam = 1;
	}
	
	public Pair<Integer,Integer> countParameters(){
		// first one is the max_depth, second one is the number of parameters
		int maxDepth = hierarchy.getDepth();
		nParam = hierarchy.getYield().size();
		return new Pair<Integer,Integer>(maxDepth, nParam);
	}
	
	public HierarchicalAdaptiveBinaryRule splitRule(short[] numSubStates, short[] newNumSubStates, Random random, double randomness, boolean doNotNormalize, int mode){
		splitRuleHelper(hierarchy, random, 8);
//		mapping = new short[newNumSubStates[this.leftChildState]][newNumSubStates[this.rightChildState]][newNumSubStates[this.parentState]];
//		int finalLevel = (int)(Math.log(mapping.length)/Math.log(2));
//		updateMapping((short)0, 0, 0, 0, 0, finalLevel, hierarchy);
		return this;
	}
	

//	private short updateMapping(short myID, int nextLeftSubstate, int nextRightSubstate, int nextParentSubstate, int myDepth, int finalDepth, Tree<Double> tree) {
//		if (tree.isLeaf()){
//			if (myDepth==finalDepth){
//				mapping[nextLeftSubstate][nextRightSubstate][nextParentSubstate] = myID;
//			} else {
//				int substatesToCover = (int)Math.pow(2,finalDepth-myDepth);
//				nextLeftSubstate *= substatesToCover;
//				nextRightSubstate *= substatesToCover;
//				nextParentSubstate *= substatesToCover;
//				for (int i=0; i<substatesToCover; i++){
//					for (int j=0; j<substatesToCover; j++){
//						for (int k=0; k<substatesToCover; k++){
//							mapping[nextLeftSubstate+i][nextRightSubstate+j][nextParentSubstate+k] = myID;
//						}
//					}
//				}
//			}
//			myID++;
//		} else {
//			int i = 0;
//			for (Tree<Double> child : tree.getChildren()){
//				myID = updateMapping(myID, nextLeftSubstate*2 + (i/4), nextRightSubstate*2 + ((i/2)%2), nextParentSubstate*2 + (i%2), myDepth+1, finalDepth, child);
//				i++;
//			}
//		}
//		return myID;
//	}

	private void splitRuleHelper(Tree<Double> tree, Random random, int splitFactor) {
		if (tree.isLeaf()){
			if (tree.getLabel()!=0||nParam==1){ // split it
				ArrayList<Tree<Double>> children = new ArrayList<Tree<Double>>(splitFactor);
				for (int i=0; i<splitFactor; i++){
					Tree<Double> child = new Tree<Double>(random.nextDouble()/100.0);
					children.add(child);
				}
				tree.setChildren(children);
				nParam += splitFactor-1;
//			} else { //perturb it
//				tree.setLabel(random.nextDouble()/100.0);
			}
		} else {
			for (Tree<Double> child : tree.getChildren()){
				splitRuleHelper(child, random, splitFactor);
			}
		}
	}

	public void explicitlyComputeScores(int finalLevel, short[] newNumSubStates){
//		int nSubstates = (int)Math.pow(2, finalLevel);
//		scores = new double[nSubstates][nSubstates][nSubstates];
//		int nextSubstate = fillScores((short)0, 0, 0, 0, 0, 0, finalLevel, hierarchy);
//		if (nextSubstate != nParam) 
//			System.out.println("Didn't fill all scores!");
		computeSubRuleList();
	}
	
//	private short fillScores(short myID, double previousScore, int nextLeftSubstate, int nextRightSubstate, int nextParentSubstate, int myDepth, int finalDepth, Tree<Double> tree){
//		if (tree.isLeaf()){
//			double myScore = Math.exp(previousScore + tree.getLabel());
//			if (myDepth==finalDepth){
//				scores[nextLeftSubstate][nextRightSubstate][nextParentSubstate] = myScore;
//			} else {
//				int substatesToCover = (int)Math.pow(2,finalDepth-myDepth);
//				nextLeftSubstate *= substatesToCover;
//				nextRightSubstate *= substatesToCover;
//				nextParentSubstate *= substatesToCover;
//				for (int i=0; i<substatesToCover; i++){
//					for (int j=0; j<substatesToCover; j++){
//						for (int k=0; k<substatesToCover; k++){
//							scores[nextLeftSubstate+i][nextRightSubstate+j][nextParentSubstate+k] = myScore;
//						}
//					}
//				}
//			}
//			myID++;
//		} else {
//			double myScore = previousScore + tree.getLabel();
//			int i = 0;
//			for (Tree<Double> child : tree.getChildren()){
//				myID = fillScores(myID, myScore,  nextLeftSubstate*2 + (i/4), nextRightSubstate*2 + ((i/2)%2), nextParentSubstate*2 + (i%2), myDepth+1, finalDepth, child);
//				i++;
//			}
//		}
//		return myID;
//	}
	
	public void updateScores(double[] scores){
		int nSubstates = updateHierarchy(hierarchy, 0, scores);
		if (nSubstates != nParam) System.out.println("Didn't update all parameters");
//		if (subRuleList!=null){
//		int i = 0;
//		for (SubRule r : subRuleList){
//			r.score = scores[this.identifier + i++];
//		}
//		}
	}

	
	private int updateHierarchy(Tree<Double> tree, int nextSubstate, double[] scores) {
		if (tree.isLeaf()){
			double val = scores[identifier + nextSubstate++];
			if (val>200) {
				val = 0;
				System.out.println("Ignored proposed binary value since it was danegrous");
			} else 
			  tree.setLabel(val);
		} else {
			for (Tree<Double> child : tree.getChildren()){
				nextSubstate = updateHierarchy(child, nextSubstate, scores);
			}
		}
		return nextSubstate;
	}

	
	public int mergeRule() { 
		int paramBefore = nParam;
		compactifyHierarchy(hierarchy);
		scores = null;
		mapping = null;
		subRuleList = null;
		scoreHierarchy = null;
		return paramBefore - nParam; 
	}

	
	/**
	 * @return
	 */
	public List<Double> getFinalLevel() {
		return hierarchy.getYield();
	}
	
	private void compactifyHierarchy(Tree<Double> tree){
		if (tree.getDepth()==2){
			boolean allZero = true;
			for (Tree<Double> child : tree.getChildren()){
				allZero = allZero && child.getLabel()==0;
			}
			if (allZero) {
				nParam -= tree.getChildren().size()-1;
				tree.setChildren(Collections.EMPTY_LIST);
			}
		} else {
			for (Tree<Double> child : tree.getChildren()){
				compactifyHierarchy(child);
			}
		}
	}
	
	public String toStringShort(){
    Numberer n = Numberer.getGlobalNumberer("tags");
    String lState = (String)n.object(leftChildState);
    String rState = (String)n.object(rightChildState);
    String pState = (String)n.object(parentState);
    return (pState+" -> "+lState+" "+rState);
	}
	
	public String toString(){
		StringBuilder sb = new StringBuilder();
    Numberer n = Numberer.getGlobalNumberer("tags");
    String lState = (String)n.object(leftChildState);
    String rState = (String)n.object(rightChildState);
    String pState = (String)n.object(parentState);
    sb.append(pState+" -> "+lState+" "+rState+"\n");
		if (subRuleList==null){
			compactifyHierarchy(hierarchy);
			lastLevel = hierarchy.getDepth();
			computeSubRuleList();
		}
		
		for (SubRule rule : subRuleList){
			sb.append(rule.toString(lastLevel-1));
			sb.append("\n");
		}
		

//		sb.append(PennTreeRenderer.render(hierarchy));
		sb.append("\n");
//		sb.append(Arrays.toString(scores));
		return sb.toString();
	}

	public int countNonZeroFeatures() {
		int total = 0;
		for (Tree<Double> d : hierarchy.getPreOrderTraversal()) { if (d.getLabel()!=0) total++; }
		return total;
	}
	
	public int countNonZeroFringeFeatures() {
		int total = 0;
		for (Tree<Double> d : hierarchy.getTerminals()) { if (d.getLabel()!=0) total++; }
		return total;
	}
	
	public void computeSubRuleList(){
		subRuleList = new SubRule[nParam];
		int nRules = computeSubRules(0, 0, 0, 0, 0, 0, hierarchy);
		if (nRules != nParam)
			System.out.println("A rule got lost");
	}
	
	private int computeSubRules(int myID, double previousScore, int nextLeftSubstate, int nextRightSubstate, int nextParentSubstate, int myDepth, Tree<Double> tree){
		if (tree.isLeaf()){
			double myScore = Math.exp(previousScore + tree.getLabel());
			SubRule rule = new SubRule((short)nextLeftSubstate, (short)nextRightSubstate, (short)nextParentSubstate, (short)myDepth, myScore);
			subRuleList[myID]=rule;
			myID++;
		} else {
			double myScore = previousScore + tree.getLabel();
			int i = 0;
			for (Tree<Double> child : tree.getChildren()){
				myID = computeSubRules(myID, myScore,  nextLeftSubstate*2 + (i/4), nextRightSubstate*2 + ((i/2)%2), nextParentSubstate*2 + (i%2), myDepth+1, child);
				i++;
			}
		}
		return myID;
	}	
	class SubRule implements Serializable{
		private static final long serialVersionUID = 1L;
		short lChild, rChild, parent, level;
		double score;
		
		SubRule(short lC, short rC, short p, short l, double s){
			lChild = lC;
			rChild = rC;
			parent = p;
			level = l;
			score = s;
		}
		
		public String toString(){
			String s =  "["+parent+"] \t -> \t ["+lChild+"] \t ["+rChild+"] \t "+score;
			return s;
		}
		
		public String toString(int finalLevel){
			if (finalLevel==level) return toString();
			int k = (int)Math.pow(2, finalLevel-level);
			String s = "["+(k*parent)+"-"+(k*parent+k-1)+"] \t -> \t ["+(k*lChild)+"-"+(k*lChild+k-1)+"] \t ["+(k*rChild)+"-"+(k*rChild+k-1)+"] \t "+score+"\t level: "+level;
			return s;
		}
		
	}
}
